-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve priority-based ordering for recompute #20117
Closed
Closed
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…l = 1 (layer wise recompute)
…pengwa/recompute_in_critical_path
pengwa
added
the
training
issues related to ONNX Runtime training; typically submitted using template
label
Mar 28, 2024
…pengwa/recompute_in_critical_path
…pengwa/recompute_in_critical_path
This reverts commit 3e03d1e.
…or input leaf nodes
Will be split into PRs. |
The first one: #20234 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Improve priority-based ordering for recompute
Critical Path Impact in Priority-Based Topological Sort
Setting and comparing critical path impact is essential in scenarios where all nodes in the priority queue
are of low priority. This comparison helps determine which recompute node to select to unblock the backward
critical path. Consider a scenario where:
- A recompute node subgraph, NodeSubgraph-N, exists within transformer layer N (e.g., NodeSubgraph-5 in
layer 5, NodeSubgraph-3 in layer 3).
- Node-A-IN-5 within NodeSubgraph-5 depends on Node-B-IN-0 within NodeSubgraph-0.
- The priority queue contains nodes from NodeSubgraph-0 to NodeSubgraph-5.
In MemoryOptimizer recompute scenarios, we append nodes starting from NodeSubgraph-5 down to NodeSubgraph-0.
Relying solely on node index for comparison could lead to:
This process can significantly delay the execution of Node-A-IN-5, blocking other NodeSubgraph-5 nodes.
Since NodeSubgraph-5 nodes are on the critical path, triggering their dependencies timely is crucial to
ensure their execution as early as possible, ahead of other layers. This necessity led to the introduction
of critical path impact.
Defining Critical Path Impact
Critical path impact is a metric representing a node's influence on the critical path. It is determined
during MemoryOptimizer's operation as follows:
1) Sort graphs without recompute optimization to establish a baseline topological order.
2) Apply recompute optimization.
3) Identify recompute boundary nodes (recompute nodes not consumed by others).
4) For each boundary node, calculate the minimum topological order of all output nodes.
The minimum value indicates the earliest need for the recompute node's execution.
We assign std::numeric_limits<int64_t>::max() - min_topological_order as the critical path impact.
5) For other recompute nodes, assign the maximum critical path impact of their output nodes.
DEFECTs
The change on priority based topo sort, give up the priority for the recompute node orders. And it need firstly a topo sort without applying mem opt, and the first sort might affect the recompute graph a lot. Here is one example:
Some nodes that are not contributing to YieldOp is treated as a non-forward ops, so it will be scheduled after YieldOp, and is possible be scheduled early, then the dependent recompute graph will be scheduled early too. This is not always most optimized.
Motivation and Context
Mistral models cannot run user recipes when enabling recompute. The root cause is, the execution order of recompute are not correct, making the memory saving very limited.